from datetime import datetime
from typing import List, Optional, Dict, Any
import random
import json
import os
import shutil
import asyncio

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal
from app.module.dataset.schema.dataset_file import DatasetFileTag
from app.module.shared.schema import TaskStatus
from app.module.ratio.schema.ratio_task import FilterCondition

# Module-level logger used by every RatioTaskService method below.
logger = get_logger(__name__)


class RatioTaskService:
    """Service for Ratio Task DB operations.

    ``create_task`` persists a ``RatioInstance`` plus one ``RatioRelation`` per
    source dataset. ``execute_dataset_ratio_task`` runs the task in the
    background: it samples files from each source dataset and copies them into
    the target dataset.
    """

    def __init__(self, db: AsyncSession):
        # Session used by create_task; background execution opens its own.
        self.db = db

    async def create_task(
        self,
        *,
        name: str,
        description: Optional[str],
        totals: int,
        config: List[Dict[str, Any]],
        target_dataset_id: Optional[str] = None,
    ) -> RatioInstance:
        """Create a ratio task instance and its relations.

        config item format::

            {"dataset_id": str, "counts": int,
             "filter_conditions": FilterCondition | dict | str | None}

        ``filter_conditions`` is normalised to a JSON string (or ``None`` when
        absent) so that ``_parse_conditions`` can read it back at execution time.

        Returns:
            The committed and refreshed ``RatioInstance``.
        """
        items = config or []
        logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(items)}")

        instance = RatioInstance(
            name=name,
            description=description,
            totals=totals,
            target_dataset_id=target_dataset_id,
            status="PENDING",
        )
        self.db.add(instance)
        await self.db.flush()  # populate instance.id

        for item in items:
            relation = RatioRelation(
                ratio_instance_id=instance.id,
                source_dataset_id=item.get("dataset_id"),
                # `or 0` also covers an explicit None, which int() would reject.
                counts=int(item.get("counts") or 0),
                filter_conditions=RatioTaskService._serialize_filter_conditions(
                    item.get("filter_conditions")
                ),
            )
            self.db.add(relation)
            # relation.id is not assigned until the next flush, so it is not logged here.
            logger.info(f"Relation created: instance={instance.id}, {relation}, {item}")

        await self.db.commit()
        await self.db.refresh(instance)
        logger.info(f"Ratio task created: {instance.id}")
        return instance

    @staticmethod
    def _get_field(obj: Any, key: str) -> Any:
        """Read ``key`` from a dict or as an attribute; return None if missing."""
        if isinstance(obj, dict):
            return obj.get(key)
        return getattr(obj, key, None)

    @staticmethod
    def _serialize_filter_conditions(conditions: Any) -> Optional[str]:
        """Serialise relation filter conditions into the JSON stored on RatioRelation.

        Args:
            conditions: an object exposing ``date_range``/``label`` attributes
                (e.g. ``FilterCondition``), an equivalent dict, an
                already-serialised JSON string, or ``None``.

        Returns:
            A JSON string with ``date_range`` and ``label`` keys. Returns ``None``
            when no conditions are supplied. ``label`` is ``None`` when absent;
            otherwise it is a ``{"label", "value"}`` mapping.
        """
        if conditions is None:
            return None
        if isinstance(conditions, str):
            return conditions or None
        get = RatioTaskService._get_field
        label = get(conditions, "label")
        return json.dumps({
            'date_range': get(conditions, "date_range"),
            'label': None if label is None else {
                "label": get(label, "label"),
                "value": get(label, "value"),
            },
        })

    # ========================= Execution (Background) ========================= #

    @staticmethod
    async def execute_dataset_ratio_task(instance_id: str) -> None:
        """Execute a ratio task in background.

        For every relation, up to ``counts`` ACTIVE files are sampled at random
        from the source dataset. When the relation has parseable
        ``filter_conditions``, only files that match them (date range / tag
        label) are eligible.

        Steps:
        - Mark instance RUNNING
        - For each relation: fetch ACTIVE files, optionally filter by conditions
        - Copy selected files into target dataset
        - Update dataset statistics and mark instance COMPLETED
        - On any error (including a failed commit), roll back and mark FAILED
        """
        async with AsyncSessionLocal() as session:  # type: AsyncSession
            try:
                # Load instance and relations
                inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
                instance: Optional[RatioInstance] = inst_res.scalar_one_or_none()
                if not instance:
                    logger.error(f"Ratio instance not found: {instance_id}")
                    return
                logger.info(f"start execute ratio task: {instance_id}")

                rel_res = await session.execute(
                    select(RatioRelation).where(RatioRelation.ratio_instance_id == instance_id)
                )
                relations: List[RatioRelation] = list(rel_res.scalars().all())

                # Mark running
                instance.status = TaskStatus.RUNNING.name

                # Load target dataset
                ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
                target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
                if not target_ds:
                    logger.error(f"Target dataset not found for instance {instance_id}")
                    instance.status = TaskStatus.FAILED.name
                    await session.commit()
                    return

                added_count, added_size = await RatioTaskService.handle_ratio_relations(relations, session, target_ds)

                # Update target dataset statistics
                target_ds.file_count = (target_ds.file_count or 0) + added_count  # type: ignore
                target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size  # type: ignore
                # If target dataset has files, mark it ACTIVE
                if (target_ds.file_count or 0) > 0:  # type: ignore
                    target_ds.status = "ACTIVE"

                # Done. Commit inside the try so that a commit failure is also
                # recorded as FAILED.
                instance.status = TaskStatus.COMPLETED.name
                await session.commit()
                # ORM attributes may be expired after commit; log only local values.
                logger.info(
                    f"Dataset ratio execution completed: instance={instance_id}, "
                    f"files={added_count}, size={added_size}, {TaskStatus.COMPLETED.name}"
                )

            except Exception as e:
                logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
                await RatioTaskService._mark_instance_failed(session, instance_id)

    @staticmethod
    async def _mark_instance_failed(session: AsyncSession, instance_id: str) -> None:
        """Best-effort: discard pending work and persist FAILED for ``instance_id``.

        Failures here are logged, not raised, so the background task ends cleanly.
        """
        try:
            # After a DB error the session refuses further SQL until it is rolled
            # back. The rollback also discards partially-added DatasetFiles rows.
            # NOTE(review): files already copied to disk are not removed.
            await session.rollback()
            inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
            instance = inst_res.scalar_one_or_none()
            if instance:
                instance.status = TaskStatus.FAILED.name
                await session.commit()
        except Exception as mark_err:
            logger.exception(f"Failed to mark ratio instance {instance_id} as FAILED: {mark_err}")

    @staticmethod
    async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
        """Sample files for each relation and copy them into ``target_ds``.

        Relations without a source dataset or with a non-positive ``counts`` are
        skipped. The session is flushed after each relation.

        Returns:
            ``(added_count, added_size)`` — the number of files added and the sum
            of their ``file_size`` values.
        """
        # Preload existing target file paths for deduplication
        existing_path_rows = await session.execute(
            select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
        )
        existing_paths = set(p for p in existing_path_rows.scalars().all() if p)

        added_count = 0
        added_size = 0

        for rel in relations:
            if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
                continue

            files = await RatioTaskService.get_files(rel, session)
            if not files:
                continue

            pick_n = min(rel.counts, len(files))
            chosen = random.sample(files, pick_n) if pick_n < len(files) else files

            # Copy into target dataset with de-dup by target path
            for f in chosen:
                await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
                added_count += 1
                added_size += int(f.file_size or 0)

            # Periodically flush to avoid huge transactions
            await session.flush()
        return added_count, added_size

    @staticmethod
    async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
        """Copy source file ``f`` under ``/dataset/<target id>/`` and stage a DatasetFiles row.

        The chosen destination path is added to ``existing_paths`` so that later
        files picked in the same run do not collide with it.
        """
        src_path = f.file_path
        dst_prefix = f"/dataset/{target_ds.id}/"
        file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)

        new_path = dst_prefix + file_name
        dst_dir = os.path.dirname(new_path)
        # Blocking filesystem work runs in a worker thread to keep the event loop free.
        await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
        await asyncio.to_thread(shutil.copy2, src_path, new_path)

        file_data = {
            "dataset_id": target_ds.id,  # type: ignore
            "file_name": file_name,
            "file_path": new_path,
            "file_type": f.file_type,
            "file_size": f.file_size,
            "check_sum": f.check_sum,
            "tags": f.tags,
            "tags_updated_at": datetime.now(),
            "dataset_filemetadata": f.dataset_filemetadata,
            "status": "ACTIVE",
        }
        # Leave None-valued columns unset so model/DB defaults apply.
        file_record = {k: v for k, v in file_data.items() if v is not None}
        session.add(DatasetFiles(**file_record))
        existing_paths.add(new_path)

    @staticmethod
    def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
        """Return a file name for ``f`` whose path under ``dst_prefix`` is not taken.

        If ``dst_prefix + f.file_name`` already exists, ``_1``, ``_2``, ... is
        appended to the stem until a free path is found.
        """
        file_name = f.file_name
        if dst_prefix + file_name not in existing_paths:
            return file_name

        file_name_base, file_ext = os.path.splitext(file_name)
        counter = 1
        # Always terminates: every candidate is distinct and existing_paths is
        # finite. Giving up early would return a colliding name and overwrite a file.
        while f"{dst_prefix}{file_name_base}_{counter}{file_ext}" in existing_paths:
            counter += 1
        return f"{file_name_base}_{counter}{file_ext}"

    @staticmethod
    async def get_files(rel: RatioRelation, session) -> list[Any]:
        """Return the ACTIVE files of ``rel``'s source dataset that pass its filter conditions."""
        # Fetch all files for the source dataset (ACTIVE only)
        files_res = await session.execute(
            select(DatasetFiles).where(
                DatasetFiles.dataset_id == rel.source_dataset_id,
                DatasetFiles.status == "ACTIVE",
            )
        )
        files = list(files_res.scalars().all())

        # Filter by relation.filter_conditions when present and parseable
        conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
        if conditions:
            files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
        return files

    # ------------------------- helpers for TAG filtering ------------------------- #

    @staticmethod
    def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
        """Parse filter_conditions JSON string into a FilterCondition object.

        Args:
            conditions: JSON string containing filter conditions

        Returns:
            A FilterCondition object when ``conditions`` is non-empty and valid.
            Returns None when it is empty or cannot be parsed; errors are logged.
        """
        if not conditions:
            return None
        try:
            data = json.loads(conditions)
            return FilterCondition(**data)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse filter conditions: {e}")
            return None
        except Exception as e:
            logger.error(f"Error creating FilterCondition: {e}")
            return None

    @staticmethod
    def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
        """Return True if ``file`` satisfies ``conditions``.

        - ``date_range`` (days, > 0): reject files whose ``tags_updated_at`` is
          older than now minus that many days. An unparsable value rejects the file.
        - ``label``: the file's tags must contain ``"<label>@<value>"``.
        """
        from datetime import timedelta

        if not conditions:
            return True
        # Called once per candidate file, so log at debug to avoid flooding info logs.
        logger.debug(f"start filter file: {file}, conditions: {conditions}")

        # Check data range condition if provided
        if conditions.date_range:
            try:
                data_range_days = int(conditions.date_range)
                if data_range_days > 0:
                    cutoff_date = datetime.now() - timedelta(days=data_range_days)
                    if file.tags_updated_at and file.tags_updated_at < cutoff_date:
                        return False
            except (ValueError, TypeError) as e:
                logger.warning(f"Invalid data_range value: {conditions.date_range}: {e}")
                return False

        # Check label condition if provided
        if conditions.label:
            tags = file.tags
            if not tags:
                return False
            try:
                tag_names = RatioTaskService.get_all_tags(tags)
                return f"{conditions.label.label}@{conditions.label.value}" in tag_names
            except Exception as e:
                logger.exception(f"Failed to get tags for {file}: {e}")
                return False

        return True

    @staticmethod
    def get_all_tags(tags) -> set[str]:
        """Collect every processed tag string from a file's raw ``tags`` value.

        Each entry of ``tags`` must be a mapping whose keys match the
        ``DatasetFileTag`` fields. The strings yielded by each tag's
        ``get_tags()`` are merged into one set. Malformed entries raise.
        """
        all_tags: set[str] = set()
        if not tags:
            return all_tags

        for tag_data in tags:
            # Keys are passed through unchanged to the DatasetFileTag model.
            file_tag = DatasetFileTag(**dict(tag_data.items()))
            all_tags.update(file_tag.get_tags())
        return all_tags
